// Copyright 2022 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "test/common/OperationMapper.h"

#include <algorithm>

#include "src/ArchitectureParams.h"
#include "src/Params.h"
#include "test/common/operation.proto.h"

// Fills the two-level loop tables of `params` and `vectorParams` for a
// convolution.
//
// The output X and Y extents are each split into an inner tile of at most 64
// and an outer trip count. Input channels (ic) and output channels (oc) are
// divided by DIMENSION.
//
// Level 0 holds the outer Y, outer X and a single-trip K2 loop. `params` pads
// level 0 with three unit loops.
//
// Level 1 loop order:
//   - `params`:       C1 (reduction), K1, FY, FX, inner Y, inner X.
//   - `vectorParams`: K1, inner Y, inner X.
//
// The weight-reuse indices point at the inner Y/X slots of level 1.
//
// NOTE(review): the outer trip counts use integer division. An ox/oy that is
// not a multiple of its tile, or an ic/oc that is not a multiple of DIMENSION,
// is silently truncated, and ox == 0 or oy == 0 divides by zero. Confirm that
// callers only pass evenly divisible, non-zero shapes.
void MapConvolution(
    const third_party::dnn_accelerator_hls::test::common::ConvolutionOperation
        &operation,
    Params &params, VectorParams &vectorParams) {
  // Inner tile sizes and outer trip counts for each dimension.
  const int xTile = std::min<int>(64, operation.ox());
  const int xTrips = operation.ox() / xTile;
  const int yTile = std::min<int>(64, operation.oy());
  const int yTrips = operation.oy() / yTile;
  const int cTrips = operation.ic() / DIMENSION;
  const int kTrips = operation.oc() / DIMENSION;
  const int kOuterTrips = 1;

  // --- params, level 0: Y1, X1, K2, then padding loops of one trip. ---
  params.loops[0][0] = yTrips;
  params.loops[0][1] = xTrips;
  params.loops[0][2] = kOuterTrips;
  for (int slot = 3; slot != 6; ++slot) {
    params.loops[0][slot] = 1;  // unused slot
  }
  params.inputYLoopIndex[0] = 0;
  params.inputXLoopIndex[0] = 1;
  params.weightLoopIndex[0] = 2;

  // --- params, level 1: C1, K1, FY, FX, inner Y, inner X. ---
  params.loops[1][0] = cTrips;
  params.loops[1][1] = kTrips;
  params.loops[1][2] = operation.fy();
  params.loops[1][3] = operation.fx();
  params.loops[1][4] = yTile;
  params.loops[1][5] = xTile;
  params.reductionLoopIndex[1] = 0;
  params.weightLoopIndex[1] = 1;
  params.fyIndex = 2;
  params.fxIndex = 3;
  params.inputYLoopIndex[1] = 4;
  params.inputXLoopIndex[1] = 5;

  // Weights are reused across the inner Y/X loops of level 1.
  params.weightReuseIndex[0] = 4;
  params.weightReuseIndex[1] = 5;

  params.STRIDE = operation.stride();

  // --- vectorParams, level 0: Y1, X1, K2. ---
  vectorParams.loops[0][0] = yTrips;
  vectorParams.loops[0][1] = xTrips;
  vectorParams.loops[0][2] = kOuterTrips;
  vectorParams.yLoopIndex[0] = 0;
  vectorParams.xLoopIndex[0] = 1;
  vectorParams.kLoopIndex[0] = 2;

  // --- vectorParams, level 1: K1, inner Y, inner X. ---
  vectorParams.loops[1][0] = kTrips;
  vectorParams.loops[1][1] = yTile;
  vectorParams.loops[1][2] = xTile;
  vectorParams.kLoopIndex[1] = 0;
  vectorParams.yLoopIndex[1] = 1;
  vectorParams.xLoopIndex[1] = 2;
}

// Maps a matrix multiplication by rewriting it as an equivalent convolution
// and delegating to MapConvolution.
//
// The rewrite uses ox = ay, ic = ax and oc = bx. The filter (fx, fy), the
// output height (oy) and the stride are all 1.
void MapMatMul(const third_party::dnn_accelerator_hls::test::common::
                   MatrixMultiplicationOperation &operation,
               Params &params, VectorParams &vectorParams) {
  third_party::dnn_accelerator_hls::test::common::ConvolutionOperation
      equivalentConv;

  // Degenerate convolution: 1x1 filter, one output row, unit stride.
  equivalentConv.set_fx(1);
  equivalentConv.set_fy(1);
  equivalentConv.set_oy(1);
  equivalentConv.set_stride(1);

  // Matrix dimensions become the remaining convolution extents.
  equivalentConv.set_ox(operation.ay());
  equivalentConv.set_ic(operation.ax());
  equivalentConv.set_oc(operation.bx());

  MapConvolution(equivalentConv, params, vectorParams);
}

// Copies the optional reshape settings of `operation` into `params`.
//
// When `operation` has no reshapes message, every reshape is disabled:
// CONCAT_HEAD, SPLIT_HEAD and TRANSPOSE are false and HEAD_SZ_LG2 is 0.
void MapReshapes(
    const third_party::dnn_accelerator_hls::test::common::TensorOperation
        &operation,
    Params &params) {
  if (operation.has_reshapes()) {
    // The message is only read, so bind it by const reference instead of
    // copying it. The reference also extends the lifetime of a temporary if
    // reshapes() returns by value.
    const third_party::dnn_accelerator_hls::test::common::Reshapes &reshapes =
        operation.reshapes();
    params.CONCAT_HEAD = reshapes.concatenate_heads();
    params.SPLIT_HEAD = reshapes.split_heads();
    // NOTE(review): the result for head_size == 0 depends on
    // Bits::Log2Ceiling. Confirm that a zero head size cannot reach here.
    params.HEAD_SZ_LG2 = Bits::Log2Ceiling(reshapes.head_size());
    params.TRANSPOSE = reshapes.transpose_matrix_b();
  } else {
    params.CONCAT_HEAD = false;
    params.SPLIT_HEAD = false;
    params.HEAD_SZ_LG2 = 0;
    params.TRANSPOSE = false;
  }
}

// Maps a tensor operation onto `params` and `vectorParams`. A tensor operation
// is a matrix multiplication or a convolution, plus optional reshapes, an
// optional activation function and optional scaling.
//
// Dies via LOG(FATAL) if the operation carries neither a matrix multiplication
// nor a convolution. Otherwise the loop tables would be left unset, and the
// run would silently use stale or uninitialized loop bounds.
void MapTensorOperation(
    const third_party::dnn_accelerator_hls::test::common::TensorOperation
        &operation,
    Params &params, VectorParams &vectorParams) {
  MapReshapes(operation, params);

  if (operation.has_matrix_multiplication()) {
    MapMatMul(operation.matrix_multiplication(), params, vectorParams);
  } else if (operation.has_convolution()) {
    MapConvolution(operation.convolution(), params, vectorParams);
  } else {
    LOG(FATAL) << "Tensor operation has neither a matrix multiplication nor a "
                  "convolution. Please double check the textproto.";
  }

  // ReLU is applied only when an activation function is present and selects
  // it. Both parameter sets receive the same flag.
  if (operation.has_activation_function()) {
    params.RELU = operation.activation_function().relu();
    vectorParams.RELU = operation.activation_function().relu();
  } else {
    params.RELU = false;
    vectorParams.RELU = false;
  }

  vectorParams.SCALE = operation.has_scaling();
  // NOTE(review): this is read even when has_scaling() is false. It then
  // presumably yields the default scale_factor of an empty scaling message,
  // which should be ignored because SCALE is false. Confirm.
  vectorParams.SCALE_FACTOR = operation.scaling().scale_factor();
}

void MapSoftmax(
    const third_party::dnn_accelerator_hls::test::common::Softmax &softmaxOp,
    VectorParams &vectorParams) {
  vectorParams.VECTOR_OFFSET = 0;
  vectorParams.OUTPUT_OFFSET = 0;
  vectorParams.SOFTMAX = true;
  vectorParams.LAYER_NORM = false;
  vectorParams.useMatrixProcessorOutput = false;
  vectorParams.SCALE = false;
  vectorParams.SPLIT_HEAD = false;
  vectorParams.RELU = false;
  for (int i = 0; i < 3; i++) {
    vectorParams.loops[0][i] = 1;
  }
  vectorParams.kLoopIndex[0] = 0;
  vectorParams.yLoopIndex[0] = 1;
  vectorParams.xLoopIndex[0] = 2;

  vectorParams.loops[1][0] = 1;
  vectorParams.kLoopIndex[1] = 0;
  vectorParams.loops[1][1] = softmaxOp.tensor_height();
  vectorParams.yLoopIndex[1] = 1;
  vectorParams.loops[1][2] = softmaxOp.tensor_width();
  vectorParams.xLoopIndex[1] = 2;
}

void MapLayerNorm(
    const third_party::dnn_accelerator_hls::test::common::LayerNorm
        &layerNormOp,
    VectorParams &vectorParams) {
  vectorParams.VECTOR_OFFSET = 0;
  vectorParams.OUTPUT_OFFSET = 0;
  vectorParams.SOFTMAX = false;
  vectorParams.LAYER_NORM = true;
  vectorParams.useMatrixProcessorOutput = false;
  vectorParams.SCALE = false;
  vectorParams.SPLIT_HEAD = false;
  vectorParams.RELU = false;
  for (int i = 0; i < 3; i++) {
    vectorParams.loops[0][i] = 1;
  }
  vectorParams.kLoopIndex[0] = 0;
  vectorParams.yLoopIndex[0] = 1;
  vectorParams.xLoopIndex[0] = 2;

  vectorParams.loops[1][0] = 1;
  vectorParams.kLoopIndex[1] = 0;
  vectorParams.loops[1][1] = layerNormOp.tensor_height();
  vectorParams.yLoopIndex[1] = 1;
  vectorParams.loops[1][2] = layerNormOp.tensor_width();
  vectorParams.xLoopIndex[1] = 2;
}

// Entry point: dispatches `operation` to the mapper for whichever kind of
// operation it holds.
//
// Dispatch order:
//   - Tensor operations configure both `params` and `vectorParams`.
//   - Softmax and layer norm configure only `vectorParams`.
//
// Any other operation is a fatal error.
void MapOperationToAccelerator(
    const third_party::dnn_accelerator_hls::test::common::Operation &operation,
    Params &params, VectorParams &vectorParams) {
  if (operation.has_tensor_operation()) {
    MapTensorOperation(operation.tensor_operation(), params, vectorParams);
    return;
  }

  if (operation.has_softmax()) {
    MapSoftmax(operation.softmax(), vectorParams);
    return;
  }

  if (operation.has_layer_norm()) {
    MapLayerNorm(operation.layer_norm(), vectorParams);
    return;
  }

  // None of the supported operation kinds is set.
  LOG(FATAL) << "Invalid or unmapped operation. Please double check the "
                "textproto.";
}
